{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Dataflowr](https://raw.githubusercontent.com/dataflowr/website/master/_assets/dataflowr_logo.png)](https://dataflowr.github.io/website/)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "G1w5n0W02uiH"
   },
   "source": [
    "# Unsupervised learning with Autoencoder\n",
    "\n",
    "We first play with MNIST dataset and pieces of code seen during the course."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "CFac0UPC2uiI"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "s7XyL-JX2uiM"
   },
   "source": [
    "## Loading MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "9yAsJijZ2uiM"
   },
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "print('Using gpu: %s ' % torch.cuda.is_available())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "6koHB8Z_2uiQ"
   },
   "outputs": [],
   "source": [
    "# to be modified if not on colab\n",
    "root_dir = './data/MNIST/'\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST(root_dir, train=True, download=True, transform=transforms.ToTensor()),\n",
    "    batch_size=256, shuffle=True)\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST(root_dir, train=False, download=True, transform=transforms.ToTensor()),\n",
    "    batch_size=10, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "s1M2kLzn2uiT"
   },
   "source": [
    "## Helper Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "wugVCUrI2uiT"
   },
   "outputs": [],
   "source": [
    "def to_img(x):\n",
    "    x = x.cpu().data.numpy()\n",
    "    x = 0.5 * (x + 1)\n",
    "    x = np.clip(x, 0, 1)\n",
    "    x = x.reshape([-1, 28, 28])\n",
    "    return x\n",
    "\n",
    "def plot_reconstructions(model):\n",
    "    \"\"\"\n",
    "    Plot 10 reconstructions from the test set. The top row is the original\n",
    "    digits, the bottom is the decoder reconstruction.\n",
    "    The middle row is the encoded vector.\n",
    "    The encoder is called by model.encoder\n",
    "    The decoder is called by model.decoder\n",
    "    \"\"\"\n",
    "    # encode then decode\n",
    "    data, _ = next(iter(test_loader))\n",
    "    data = data.view([-1, 784])\n",
    "    data.requires_grad = False\n",
    "    data = data.to(device)\n",
    "    true_imgs = data\n",
    "    encoded_imgs = model.encoder(data)\n",
    "    decoded_imgs = model.decoder(encoded_imgs)\n",
    "    \n",
    "    true_imgs = to_img(true_imgs)\n",
    "    decoded_imgs = to_img(decoded_imgs)\n",
    "    encoded_imgs = encoded_imgs.cpu().data.numpy()\n",
    "    \n",
    "    n = 10\n",
    "    plt.figure(figsize=(20, 4))\n",
    "    for i in range(n):\n",
    "        # display original\n",
    "        ax = plt.subplot(3, n, i + 1)\n",
    "        plt.imshow(true_imgs[i])\n",
    "        plt.gray()\n",
    "        ax.get_xaxis().set_visible(False)\n",
    "        ax.get_yaxis().set_visible(False)\n",
    "        \n",
    "        ax = plt.subplot(3, n, i + 1 + n)\n",
    "        plt.imshow(encoded_imgs[i].reshape(-1,4))\n",
    "        plt.gray()\n",
    "        ax.get_xaxis().set_visible(False)\n",
    "        ax.get_yaxis().set_visible(False)\n",
    "\n",
    "        # display reconstruction\n",
    "        ax = plt.subplot(3, n, i + 1 + n + n)\n",
    "        plt.imshow(decoded_imgs[i])\n",
    "        plt.gray()\n",
    "        ax.get_xaxis().set_visible(False)\n",
    "        ax.get_yaxis().set_visible(False)\n",
    "    \n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "dRAI3MaO2uiW"
   },
   "source": [
    "## Simple Auto-Encoder\n",
    "\n",
    "We'll start with the simplest autoencoder: a single, fully-connected layer as the encoder and decoder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "GApDwy6e2uiW"
   },
   "outputs": [],
   "source": [
    "class AutoEncoder(nn.Module):\n",
    "    def __init__(self, input_dim, encoding_dim):\n",
    "        super(AutoEncoder, self).__init__()\n",
    "        self.encoder = nn.Linear(input_dim, encoding_dim)\n",
    "        self.decoder = nn.Linear(encoding_dim, input_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        encoded = F.relu(self.encoder(x))\n",
    "        decoded = self.decoder(encoded)\n",
    "        return decoded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "oUbMKILV2uiZ"
   },
   "outputs": [],
   "source": [
    "input_dim = 784\n",
    "encoding_dim = 64\n",
    "\n",
    "model = AutoEncoder(input_dim, encoding_dim)\n",
    "model = model.to(device)\n",
    "optimizer = optim.Adam(model.parameters())\n",
    "loss_fn = torch.nn.MSELoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ZaNO1R_G2uid"
   },
   "source": [
    "Why did we take 784 as input dimension?\n",
    "\n",
    "To find the learning rate, see the documentation for [Adam optimizer](https://pytorch.org/docs/stable/optim.html#torch.optim.Adam)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "e3Zh8eQ_2uid"
   },
   "outputs": [],
   "source": [
    "def train_model(model,loss_fn,data_loader=None,epochs=1,optimizer=None):\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        for batch_idx, (data, _) in enumerate(train_loader):\n",
    "            data = data.view([-1, 784]).to(device)\n",
    "            optimizer.zero_grad()\n",
    "            output = model(data)\n",
    "            loss = loss_fn(output, data)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            if batch_idx % 100 == 0:\n",
    "                print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                    epoch, batch_idx * len(data), len(data_loader.dataset),\n",
    "                    100. * batch_idx / len(data_loader), loss.data.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "dybe6kLJ2uig",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "train_model(model, loss_fn, data_loader=train_loader, epochs=10, optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "NOPJgAUh2uij"
   },
   "outputs": [],
   "source": [
    "plot_reconstructions(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "xhCiHP7P2uil"
   },
   "source": [
    "## 1. Exercise: Stacked Autoencoder\n",
    "\n",
    "Now you will code an autoencoder where both the encoder and the decoder are multilayer perceptron (MLP). You can take for the encoder a first hidden layer with dimension 128, a second one with dimension 64 and then the code of dimension 32. For the decoder, you can take the same sequence of dimensions in reverse order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "SrTmxg7z2uim"
   },
   "outputs": [],
   "source": [
    "class DeepAutoEncoder(nn.Module):\n",
    "    def __init__(self, input_dim, encoding_dim):\n",
    "        super(DeepAutoEncoder, self).__init__()\n",
    "        #\n",
    "        # your code here\n",
    "        #\n",
    "        \n",
    "    def forward(self, x):\n",
    "        #\n",
    "        # your code here\n",
    "        #"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Bm_n3sI92uiq"
   },
   "outputs": [],
   "source": [
    "input_dim = 784\n",
    "encoding_dim = 32\n",
    "\n",
    "model = DeepAutoEncoder(input_dim, encoding_dim)\n",
    "model = model.to(device)\n",
    "optimizer = optim.Adam(model.parameters())\n",
    "loss_fn = torch.nn.MSELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "8VNi4vtk2uit",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "train_model(model, loss_fn,data_loader=train_loader,epochs=10,optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "nraDkOnC2uiv"
   },
   "outputs": [],
   "source": [
    "plot_reconstructions(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "XIWLi9Ps2uix"
   },
   "source": [
    "Replace the `MSELoss` with a `BCEWithLogitsLoss` for each pixel. Note the unusual use of `BCEWithLogitsLoss`! You can have a look at the definition of [Cross Entropy](https://en.wikipedia.org/wiki/Cross_entropy)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "9S8qgLsO2uix"
   },
   "source": [
    "## 2. Optional\n",
    "\n",
    "At this stage, you can code the interpolation described in the lesson to obtain:\n",
    "\n",
    "![](https://raw.githubusercontent.com/dataflowr/slides/master/images/module9/interp_AE.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "9VCJLQXE2ujS"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Z3sm7Pi62ujU"
   },
   "source": [
    "# 3. Exercise: Implement a denoising AE:\n",
    "\n",
    "\n",
    "Use previous code and with minimal modifications, transform your AE in a denoising AE. Now, you first apply some noise to your input and try to recover the original data at the output. For the noise, you can add some random noise or erase some of the pixels. In this last case, you should obtain something like: \n",
    "\n",
    "![](https://raw.githubusercontent.com/dataflowr/slides/master/images/module9/denoising_AE.png)\n",
    "\n",
    "The first line corresponds to the original digit, the second line to the noisy version of the digit given as input to your network, the third line is the associated code and the last line is the denoised digit obtained by your decoder from the code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "isYFcQFS2ujU"
   },
   "outputs": [],
   "source": [
    "# You need first to modify the training process by adding noise to your data\n",
    "# Hint if you want to erase pixels: https://stackoverflow.com/questions/49216615/is-there-an-efficient-way-to-create-a-random-bit-mask-in-pytorch\n",
    "def train_denoiser(model,loss_fn,data_loader=None,epochs=1,optimizer=None, noise=0.1):\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "      for batch_idx, (data, _) in enumerate(train_loader):\n",
    "        #\n",
    "        # your code here to create the noisy_data\n",
    "        #\n",
    "        optimizer.zero_grad()\n",
    "        output = model(noisy_data)\n",
    "        loss = loss_fn(output, data)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "            \n",
    "        if batch_idx % 100 == 0:\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                epoch, batch_idx * len(data), len(data_loader.dataset),\n",
    "                100. * batch_idx / len(data_loader), loss.data.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "EBkQbfVl2ujX"
   },
   "outputs": [],
   "source": [
    "input_dim = 784\n",
    "encoding_dim = 32\n",
    "\n",
    "model = DeepAutoEncoder(input_dim, encoding_dim)\n",
    "model = model.to(device)\n",
    "optimizer = optim.Adam(model.parameters())\n",
    "loss_fn = torch.nn.BCEWithLogitsLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "KlYqlT822ujZ"
   },
   "outputs": [],
   "source": [
    "train_denoiser(model, loss_fn,data_loader=train_loader,epochs=10,optimizer=optimizer, noise=0.8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "LGixZVJF2ujW"
   },
   "outputs": [],
   "source": [
    "# Now you need to modify the plot function\n",
    "def plot_denoising(model, noise=0.1):\n",
    "    \"\"\"\n",
    "    Plot 10 reconstructions from the test set. The top row is the original\n",
    "    digits, , the second row is the noisy digits, \n",
    "    the third row is the encoded vector and\n",
    "    the bottom is the decoder reconstruction.\n",
    "    \"\"\"\n",
    "    # encode then decode\n",
    "    data, _ = next(iter(test_loader))\n",
    "    #\n",
    "    # your code here to compute\n",
    "    # noisy_data\n",
    "    # encoded_imgs\n",
    "    # decoded_imgs\n",
    "    #\n",
    "    true_imgs = to_img(data)\n",
    "    noisy_imgs = to_img(noisy_data)\n",
    "    decoded_imgs = to_img(decoded_imgs)\n",
    "    encoded_imgs = encoded_imgs.cpu().data.numpy()\n",
    "    \n",
    "    n = 10\n",
    "    plt.figure(figsize=(20, 4))\n",
    "    for i in range(n):\n",
    "        # display original\n",
    "        ax = plt.subplot(4, n, i + 1)\n",
    "        plt.imshow(true_imgs[i])\n",
    "        plt.gray()\n",
    "        ax.get_xaxis().set_visible(False)\n",
    "        ax.get_yaxis().set_visible(False)\n",
    "        \n",
    "        # display corrupted original\n",
    "        ax = plt.subplot(4, n, i + 1 +n)\n",
    "        plt.imshow(noisy_imgs[i])\n",
    "        plt.gray()\n",
    "        ax.get_xaxis().set_visible(False)\n",
    "        ax.get_yaxis().set_visible(False)\n",
    "        \n",
    "        # display code\n",
    "        ax = plt.subplot(4, n, i + 1 + 2*n)\n",
    "        plt.imshow(encoded_imgs[i].reshape(-1,4))\n",
    "        plt.gray()\n",
    "        ax.get_xaxis().set_visible(False)\n",
    "        ax.get_yaxis().set_visible(False)\n",
    "\n",
    "        # display reconstruction\n",
    "        ax = plt.subplot(4, n, i + 1 +  3*n)\n",
    "        plt.imshow(decoded_imgs[i])\n",
    "        plt.gray()\n",
    "        ax.get_xaxis().set_visible(False)\n",
    "        ax.get_yaxis().set_visible(False)\n",
    "    \n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "RFLDJ6f_2ujb"
   },
   "outputs": [],
   "source": [
    "plot_denoising(model, noise=0.8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ddKvatCh2ujn"
   },
   "source": [
    "# 4. Optional: how to deal with convolutions?\n",
    "\n",
    "Hint: start by decreasing the size of your image with `Conv2d` by using a `stride` like:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "conv = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1, stride=2)\n",
    "x = torch.randn(2, 8, 64, 64)\n",
    "y = conv(x)\n",
    "y.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now use [transposed convolution](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html) (or [deconvolution](https://distill.pub/2016/deconv-checkerboard/)) with the same parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "convt = nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=3, padding=1, stride=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "convt(y).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get the same size as `x`, play with `output_padding`.\n",
    "\n",
    "Now, you have all the tools to build a convolutional autoencoder!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Dataflowr](https://raw.githubusercontent.com/dataflowr/website/master/_assets/dataflowr_logo.png)](https://dataflowr.github.io/website/)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "include_colab_link": false,
   "name": "09_AE_NoisyAE.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
